{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "90c91940-722b-4f5a-bf9f-3549f0626643",
   "metadata": {},
   "source": [
    "# TorchMetrics -- How do we use it, and what's the difference between .update() and .forward()?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d439ca76-b505-40f5-a105-0919303af705",
   "metadata": {},
   "source": [
    "[TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/pages/overview.html) is a really nice and convenient library that lets us compute the performance of models in an iterative fashion. It's designed with PyTorch (and PyTorch Lightning) in mind, but it is a general-purpose library compatible with other libraries and workflows.\n",
    "\n",
    "This iterative computation is useful if we want to track a model during iterative training or evaluation on minibatches (and optionally across multiple GPUs). In deep learning, that's essentially *all the time*.\n",
    "\n",
    "Here, its object-oriented programming-based implementation helps keep track of things. \n",
    "\n",
    "**However, when using TorchMetrics, one common question is whether we should use `.update()` or `.forward()`. (That's also a question I certainly had when I started using it.)**\n",
    "\n",
    "While the [documentation explains](https://torchmetrics.readthedocs.io/en/latest/pages/implement.html#internal-implementation-details) what's going on when calling `.forward()`, it may make sense to augment that explanation with a hands-on example in this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2f2a081-eac7-4e19-aa5c-2a88cb7d1318",
   "metadata": {},
   "source": [
    "## Computing the accuracy manually for comparison"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "048008b1-2ca5-4006-83a6-de6a5e8b3b32",
   "metadata": {},
   "source": [
    "While TorchMetrics allows us to compute much fancier things (e.g., see the confusion matrix at the bottom of my other notebook [here](https://github.com/rasbt/deeplearning-models/blob/master/pytorch-lightning_ipynb/mlp/mlp-basic.ipynb)), let's stick to the regular classification accuracy since it is more minimalist and intuitive to compute."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7589dd53-7c8d-4cf2-90a1-25c9c218fb94",
   "metadata": {},
   "source": [
    "Also, before we dive into TorchMetrics, let's compute the accuracy ***manually*** to ensure we have a baseline for understanding how TorchMetrics works.\n",
    "\n",
    "Suppose we are training a model for an epoch that consists of 10 minibatches. In the following code, we will simulate this via the outer for-loop (`for i in range(10):`).\n",
    "\n",
    "Moreover, instead of using an actual dataset and model, let's pretend that\n",
    "\n",
    "- `y_true = torch.randint(low=0, high=2, size=(10,))` is a tensor with the ground truth labels for our minibatch. It consists of ten 0's and 1's.  \n",
    "(E.g., it is similar to something like  \n",
    "`torch.tensor([0, 1, 1, 0, 1, 0, 1, 1, 0, 1])`.)\n",
    "\n",
    "- `y_pred = torch.randint(low=0, high=2, size=(10,))` is a tensor with the predicted class labels for our minibatch. Like `y_true`, it consists of ten 0's and 1's.\n",
    "\n",
    "Via `torch.manual_seed(123)`, we ensure that the code is reproducible and gives us exactly the same results each time we execute the following code cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ef372d44-714e-46f6-b515-7638be93be37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "all_true, all_pred = [], []\n",
    "\n",
    "for i in range(10):\n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    all_true.append(y_true)\n",
    "    all_pred.append(y_pred)\n",
    "\n",
    "correct_pred = (torch.cat(all_true) == torch.cat(all_pred)).float()\n",
    "    \n",
    "acc = torch.mean(correct_pred)\n",
    "print('Overall accuracy:', acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26fdebbc-fcb7-4179-9c84-711a6cc3c38a",
   "metadata": {},
   "source": [
    "So, what we have done above is collect all the true class labels in `all_true` and all the predicted class labels in `all_pred`. Then, we compared the two element-wise, which gives us a tensor, `correct_pred`, that contains a 1 wherever the true and the predicted label match and a 0 otherwise. Finally, we computed the mean of this tensor, that is, the fraction of correct predictions, which is our accuracy.\n",
    "\n",
    "If we work with large datasets, it would be wasteful to accumulate all the labels via `all_true` and `all_pred` (and, in the worst case, we could exceed the GPU memory). A smarter way is to count the number of correct predictions and then divide that number by the total number of training examples as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5820e86d-d529-481b-b5cf-32b0538d7ad1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "num = 0\n",
    "correct = 0.\n",
    "\n",
    "for i in range(10):\n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    correct += (y_true == y_pred).float().sum()\n",
    "    num += y_true.numel()\n",
    "    \n",
    "acc = correct / num\n",
    "print('Overall accuracy:', acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bf22b7b-fe9c-47da-902c-a9b645b946b3",
   "metadata": {},
   "source": [
    "## Using TorchMetrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1261eec0-fe5c-4345-9d51-a9124c4cb2ea",
   "metadata": {},
   "source": [
    "So, TorchMetrics allows us to do what we have done in the previous section; that is, iteratively computing a metric. \n",
    "\n",
    "The general steps are as follows:\n",
    "\n",
    "1. We initialize a metric we want to compute (here: accuracy). \n",
    "2. We call `.update()` during the training loop. \n",
    "3. Finally, we call `.compute()` to get the final accuracy value when we are done.\n",
    "\n",
    "Let's take a look at what this looks like in code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a4a4cb73-61ba-4852-904c-426ab16539c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Batch accuracy: None\n",
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "from torchmetrics.classification import Accuracy\n",
    "\n",
    "\n",
    "train_acc = Accuracy()\n",
    "torch.manual_seed(123)\n",
    "\n",
    "for i in range(10):\n",
    "    \n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    \n",
    "    # TorchMetrics expects the arguments in the order (preds, target)\n",
    "    abc = train_acc.update(y_pred, y_true)\n",
    "    print('Batch accuracy:', abc)\n",
    "    \n",
    "print('Overall accuracy:', train_acc.compute())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e900c76d-a2b6-4764-9075-f86c4631bda3",
   "metadata": {},
   "source": [
    "Notice that the overall accuracy is the same as the one we got from computing it manually in the previous section. For reference, we also printed the return value of `.update()` for each batch; however, there is nothing interesting here because it's always `None`. The following code example will make it clear why we did that.\n",
    "\n",
    "So, in the following code example, we make a small modification to the training loop. Now, we are calling `train_acc.forward()` (or, to be more precise, the equivalent shortcut `train_acc()`) instead of `train_acc.update()`. The `.forward()` call does a bunch of things under the hood, which we will talk about later. For now, let's just inspect the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "53a6daab-7474-4db8-8413-a029cb9a64ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch accuracy: tensor(0.7000)\n",
      "Batch accuracy: tensor(0.7000)\n",
      "Batch accuracy: tensor(0.5000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.4000)\n",
      "Batch accuracy: tensor(0.4000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.5000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "train_acc = Accuracy()\n",
    "\n",
    "torch.manual_seed(123)\n",
    "for i in range(10):\n",
    "    \n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    \n",
    "    # the following two lines are equivalent\n",
    "    # (TorchMetrics expects the arguments in the order (preds, target)):\n",
    "    # abc = train_acc.forward(y_pred, y_true) \n",
    "    abc = train_acc(y_pred, y_true) \n",
    "    \n",
    "    print('Batch accuracy:', abc)\n",
    "    \n",
    "print('Overall accuracy:', train_acc.compute())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9f6b80b-285c-48fb-be4a-529cffed61b9",
   "metadata": {},
   "source": [
    "As we can see, the overall accuracy is the same as before (as we would expect :)). However, we now also have the intermediate results: the batch accuracies. The batch accuracy refers to the accuracy for the given minibatch. For reference, below is what it looks like if we compute it manually:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83a4b954-6f31-42f5-b738-16c8584af79d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch accuracy: tensor(0.7000)\n",
      "Batch accuracy: tensor(0.7000)\n",
      "Batch accuracy: tensor(0.5000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.4000)\n",
      "Batch accuracy: tensor(0.4000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Batch accuracy: tensor(0.5000)\n",
      "Batch accuracy: tensor(0.6000)\n",
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "num = 0\n",
    "correct = 0.\n",
    "\n",
    "for i in range(10):\n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    \n",
    "    correct_batch = (y_true == y_pred).float().sum()\n",
    "    correct += correct_batch\n",
    "    num += y_true.numel()\n",
    "    \n",
    "    abc = correct_batch / y_true.numel()\n",
    "    print('Batch accuracy:', abc) \n",
    "    \n",
    "acc = correct / num\n",
    "print('Overall accuracy:', acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df67ea81-e5fb-42fb-ace2-39c3020a7a57",
   "metadata": {},
   "source": [
    "If we are interested in the validation set or test accuracy, this intermediate result is maybe not super useful. However, it can be handy for tracking the training set accuracy during training. Also, it is useful for things like the loss function. This way, we can plot both the intermediate loss per minibatch and the average loss per epoch with only one pass over the training set."
   ]
  },
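  {
   "cell_type": "markdown",
   "id": "3b7e5c1a-2d4f-4a6b-9c8e-1f0a2b3c4d5e",
   "metadata": {},
   "source": [
    "The loss-tracking idea above can be sketched without any library. Note that the loss values and batch sizes below are made up purely for illustration, and the variable names are placeholders rather than anything from TorchMetrics. The sketch reports each per-batch loss as it arrives and, in the same pass, accumulates what is needed for the epoch average:\n",
    "\n",
    "```python\n",
    "# made-up per-batch mean losses and batch sizes (illustration only)\n",
    "batch_losses = [0.9, 0.7, 0.6]\n",
    "batch_sizes = [10, 10, 5]\n",
    "\n",
    "loss_sum = 0.0      # sum of the per-example losses seen so far\n",
    "num_examples = 0    # number of examples seen so far\n",
    "\n",
    "for loss, size in zip(batch_losses, batch_sizes):\n",
    "    # intermediate result: the loss of the current minibatch\n",
    "    print('Batch loss:', loss)\n",
    "    # accumulated state: weight each batch mean by its batch size\n",
    "    loss_sum += loss * size\n",
    "    num_examples += size\n",
    "\n",
    "# final result: the average loss over all examples in the epoch\n",
    "epoch_loss = loss_sum / num_examples\n",
    "print('Epoch loss:', epoch_loss)\n",
    "```\n",
    "\n",
    "Weighting by the batch size matters when the last minibatch is smaller than the others; simply averaging the three batch losses would overweight the examples in the small batch."
   ]
  },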
  {
   "cell_type": "markdown",
   "id": "e196d36e-35e7-42f0-82e8-5f2d092618cf",
   "metadata": {},
   "source": [
    "## .update() vs .forward() -- the official explanation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae322bcb-dee7-45dd-b119-753734ff8b20",
   "metadata": {},
   "source": [
    "So, in the previous section, we saw that `.forward()` and `.update()` do slightly different things. The `.update()` method is somewhat simpler: it just updates the metric state and returns nothing. In contrast, `.forward()` updates the metric state and also returns the metric value for the current batch. The `.forward()` method is essentially a more sophisticated method that uses `.update()` under the hood. \n",
    "\n",
    "Which method should we use? It depends on your use case. If we don't care about tracking or logging the intermediate results, using `.update()` should suffice. However, calling `.forward()` is usually computationally very cheap -- in the grand scheme of training deep neural networks -- so it's not harmful to default to using `.forward()`, either.\n",
    "\n",
    "If you are interested in the nitty-gritty details, have a look at the following excerpt from the official [documentation](https://torchmetrics.readthedocs.io/en/latest/pages/implement.html#internal-implementation-details):"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d36e24f5-0722-474f-9e7b-3d569e4d7a19",
   "metadata": {},
   "source": [
    "> The `forward()` method achieves this by combining calls to `update` and `compute` in the following way:\n",
    ">\n",
    "> 1. Calls `update()` to update the global metric state (for accumulation over multiple batches)\n",
    "> 2. Caches the global state.\n",
    "> 3. Calls `reset()` to clear global metric state.\n",
    "> 4. Calls `update()` to update local metric state.\n",
    "> 5. Calls `compute()` to calculate metric for current batch.\n",
    "> 6. Restores the global state.\n",
    ">\n",
    "> This procedure has the consequence of calling the user defined `update` twice during a single `forward` call (one to update global statistics and one for getting the batch statistics)."
   ]
  },
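  {
   "cell_type": "markdown",
   "id": "7c2d9e4f-1a3b-4c5d-8e6f-0a1b2c3d4e5f",
   "metadata": {},
   "source": [
    "The six steps quoted above can be sketched with a small toy metric. This is a hypothetical, dependency-free illustration: the `ToyAccuracy` class and its attributes are made up for this notebook, and plain Python lists stand in for tensors. It is not the actual TorchMetrics implementation, which may differ in its details.\n",
    "\n",
    "```python\n",
    "class ToyAccuracy:\n",
    "    # hypothetical toy metric, only meant to mirror the six quoted steps\n",
    "\n",
    "    def __init__(self):\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        # clear the accumulated state\n",
    "        self.correct = 0\n",
    "        self.total = 0\n",
    "\n",
    "    def update(self, preds, target):\n",
    "        # add the counts of the current batch to the state\n",
    "        self.correct += sum(int(p == t) for p, t in zip(preds, target))\n",
    "        self.total += len(target)\n",
    "\n",
    "    def compute(self):\n",
    "        # turn the current state into the metric value\n",
    "        return self.correct / self.total\n",
    "\n",
    "    def forward(self, preds, target):\n",
    "        self.update(preds, target)           # 1. update the global state\n",
    "        cache = (self.correct, self.total)   # 2. cache the global state\n",
    "        self.reset()                         # 3. clear the state\n",
    "        self.update(preds, target)           # 4. update the (now local) state\n",
    "        batch_value = self.compute()         # 5. metric for the current batch\n",
    "        self.correct, self.total = cache     # 6. restore the global state\n",
    "        return batch_value\n",
    "\n",
    "    # calling the object is a shortcut for calling forward()\n",
    "    __call__ = forward\n",
    "\n",
    "\n",
    "metric = ToyAccuracy()\n",
    "batch_1 = metric([1, 0, 1, 1], [1, 0, 0, 1])  # 3 of 4 correct\n",
    "batch_2 = metric([0, 0, 1, 1], [1, 1, 1, 1])  # 2 of 4 correct\n",
    "overall = metric.compute()                    # 5 of 8 correct\n",
    "print(batch_1, batch_2, overall)\n",
    "```\n",
    "\n",
    "Running this prints `0.75 0.5 0.625`: each call returns the accuracy of its own batch, while the accumulated state still yields the overall accuracy at the end. It also makes the cost visible, since `update()` runs twice per `forward()` call."
   ]
  },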
  {
   "cell_type": "markdown",
   "id": "0e01b166-2f88-44b8-ad7b-4db0a753a77d",
   "metadata": {},
   "source": [
    "## Bonus: Computing the running average"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02ed3b65-a830-4bbf-a09a-b3064ec3394b",
   "metadata": {},
   "source": [
    "The fact that there are separate `.update()` and `.compute()` methods allows us to compute the running average of a metric (rather than the metric on a single batch). To achieve this, we can include a `.compute()` call inside the loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b7930154-5c7c-4f84-b945-16938ac8460f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running accuracy: tensor(0.7000)\n",
      "Running accuracy: tensor(0.7000)\n",
      "Running accuracy: tensor(0.6333)\n",
      "Running accuracy: tensor(0.6250)\n",
      "Running accuracy: tensor(0.5800)\n",
      "Running accuracy: tensor(0.5500)\n",
      "Running accuracy: tensor(0.5571)\n",
      "Running accuracy: tensor(0.5625)\n",
      "Running accuracy: tensor(0.5556)\n",
      "Running accuracy: tensor(0.5600)\n",
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "from torchmetrics.classification import Accuracy\n",
    "\n",
    "\n",
    "train_acc = Accuracy()\n",
    "torch.manual_seed(123)\n",
    "\n",
    "for i in range(10):\n",
    "    \n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    \n",
    "    # TorchMetrics expects the arguments in the order (preds, target)\n",
    "    train_acc.update(y_pred, y_true)\n",
    "    abc = train_acc.compute()\n",
    "    print('Running accuracy:', abc)\n",
    "    \n",
    "print('Overall accuracy:', train_acc.compute())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a9c732e-86cd-4919-a32b-d72ec9fff66b",
   "metadata": {},
   "source": [
    "The preceding code is equivalent to the following manual computation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20f6a7af-9604-4515-9d95-5825969fae90",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running accuracy: tensor(0.7000)\n",
      "Running accuracy: tensor(0.7000)\n",
      "Running accuracy: tensor(0.6333)\n",
      "Running accuracy: tensor(0.6250)\n",
      "Running accuracy: tensor(0.5800)\n",
      "Running accuracy: tensor(0.5500)\n",
      "Running accuracy: tensor(0.5571)\n",
      "Running accuracy: tensor(0.5625)\n",
      "Running accuracy: tensor(0.5556)\n",
      "Running accuracy: tensor(0.5600)\n",
      "Overall accuracy: tensor(0.5600)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "num = 0\n",
    "correct = 0.\n",
    "\n",
    "for i in range(10):\n",
    "    y_true = torch.randint(low=0, high=2, size=(10,))\n",
    "    y_pred = torch.randint(low=0, high=2, size=(10,))\n",
    "    \n",
    "    correct_batch = (y_true == y_pred).float().sum()\n",
    "    correct += correct_batch\n",
    "    num += y_true.numel()\n",
    "    \n",
    "    abc = correct / num\n",
    "    print('Running accuracy:', abc) \n",
    "    \n",
    "acc = correct / num\n",
    "print('Overall accuracy:', acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
